{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cc73d08d-19f6-49aa-a85a-2d2d52c581da",
   "metadata": {},
   "source": [
    "# Import pose estimation model\n",
    "\n",
    "## Define output format\n",
    "\n",
    "Let's load the JSON file, which describes the body format, a slightly modified version of the MS-COCO format (BODY18). This body format is used to create a topology tensor that associate Part Affinity Field (PAF) channels to their corresponding human body part and connect them to generate a skeleton.\n",
    "\n",
    "Reference: [TensorRT-Pose repository](https://github.com/NVIDIA-AI-IOT/trt_pose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7a34bead-4989-4e7b-8c7f-2f09b611375f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['MPLCONFIGDIR'] = os.getcwd() + \"/configs/\" # Specify MatplotLib config folder\n",
    "import numpy as np\n",
    "import json\n",
    "# Requiere https://github.com/NVIDIA-AI-IOT/trt_pose\n",
    "import trt_pose.coco\n",
    "from trt_pose.draw_objects import DrawObjects\n",
    "from trt_pose.parse_objects import ParseObjects\n",
    "\n",
    "with open('human_pose.json', 'r') as f:\n",
    "    human_pose = json.load(f)\n",
    "\n",
    "topology = trt_pose.coco.coco_category_to_topology(human_pose)\n",
    "\n",
    "parse_objects = ParseObjects(topology)\n",
    "draw_objects = DrawObjects(topology)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7c31468-df1c-43c2-bd59-40c50401e2e0",
   "metadata": {},
   "source": [
    "## Import the pose estimation optimized model\n",
    "\n",
    "Next, we'll load our PyTorch pose estimation model. It has been optimized using another Notebook and saved so that we do not need to perform optimization again. The optimization procedure is detailed in the TensorRT-Pose repository. Please note that TensorRT has device-specific optimizations, so you can only use an optimized model on the same platform.\n",
    "\n",
    "Reference: [Torch2TRT repository](https://github.com/NVIDIA-AI-IOT/torch2trt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fc10e8eb-786c-438e-810f-77f069f5ff0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch2trt import TRTModule\n",
    "\n",
    "TRT_POSE_PATH = 'resnet18_baseline_att_224x224_trt_FP16_33MB.pth'\n",
    "\n",
    "model_estimation = TRTModule()\n",
    "model_estimation.load_state_dict(torch.load(TRT_POSE_PATH))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df44a8b7-fe51-470b-86d4-9ec5135f2a49",
   "metadata": {},
   "source": [
    "## Import the pose classification model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "acbd9091-6c13-4456-83b5-a072db26c73a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.python.saved_model import tag_constants\n",
    "\n",
    "TRT_CLASSIFICATION_PATH = './Robust_BODY18_FP32'\n",
    "\n",
    "model_classification = tf.saved_model.load(TRT_CLASSIFICATION_PATH, tags=[tag_constants.SERVING])\n",
    "infer_classification = model_classification.signatures['serving_default']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1689bede-f28c-4533-83f5-bcd52df80e88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dense_20': <tf.Tensor: shape=(1, 20), dtype=float32, numpy=\n",
       " array([[4.0254830e-19, 1.2014065e-29, 7.3594776e-19, 5.5553655e-22,\n",
       "         1.2282521e-17, 2.4289520e-25, 9.9999976e-01, 8.9502602e-27,\n",
       "         2.0831374e-10, 3.7753817e-10, 2.3662060e-19, 2.4505236e-21,\n",
       "         5.7214940e-08, 2.1032307e-07, 1.8473210e-26, 6.7222669e-14,\n",
       "         2.5704211e-26, 1.0280534e-08, 1.4103428e-30, 3.8663297e-15]],\n",
       "       dtype=float32)>}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Blank inference to load model\n",
    "infer_classification(tf.constant(\n",
    "    np.random.normal(size=(1, 18, 2)).astype(np.float32),\n",
    "    dtype=tf.float32,\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2ddf35c3-4304-406c-8c16-e7e5661936ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels: ['Seated', 'Stand', 'Stand_RightArmRaised', 'Stand_LeftArmRaised', 'T', 'MilitarySalute', 'PushUp_Low', 'Squat', 'Plank', 'Yoga_Tree_left', 'Yoga_Tree_right', 'Yoga_UpwardSalute', 'Yoga_Warrior2_left', 'Yoga_Warrior2_right', 'Traffic_AllStop', 'Traffic_BackStop', 'Traffic_FrontStop', 'Traffic_BackFrontStop', 'Traffic_LeftTurn', 'Traffic_RightTurn']\n"
     ]
    }
   ],
   "source": [
    "with open('Robust_BODY18_Info.json') as f:\n",
    "    classificationLabels = json.load(f)['labels']\n",
    "print(\"labels:\", classificationLabels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2dbed1cb-6960-4223-bb9e-fe2671f69e8d",
   "metadata": {},
   "source": [
    "# Define video-processing pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a60b9d20-30cb-4f17-a528-e87b1678a64e",
   "metadata": {},
   "source": [
    "## Access video feed\n",
    "\n",
    "The whole video acquisition pipeline, incuding scaling and cropping, is done using the accelerated GStreamer for NVidia Tegra processors.  See console for details about the video acquisition pipeline. This pipeline is used by OpenCV to access images.\n",
    "\n",
    "References: [User Guide](https://developer.download.nvidia.com/embedded/L4T/r32_Release_v1.0/Docs/Accelerated_GStreamer_User_Guide.pdf?UliDteoP_g5QqgRwKoNbj3abiW9TeMtEWNumYbfdqeWY6oSlJaPISqf04banob6ohwLwYKvWmMUjwI8EWpk3f8lpapB3XvQGRxGPej5eiHmM_QA-AHiAenmymLlFAs1QmtZHTwE4FL_o2GYBqCc1M8ggJJcgb5w6whYwSFe9sK7rp3avYyw), [Documentation](https://docs.nvidia.com/jetson/l4t/index.html#page/Tegra%20Linux%20Driver%20Package%20Development%20Guide/accelerated_gstreamer.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fb73d4c8-63d4-437c-b801-7624935fad61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nvarguscamerasrc ! video/x-raw(memory:NVMM), width=(int)398, height=(int)224, format=(string)NV12, framerate=(fraction)60/1 ! nvvidconv top=0 bottom=224 left=87 right=311 flip-method=2 ! video/x-raw, width=(int)224, height=(int)224, format=(string)BGRx ! videoconvert ! video/x-raw, format=(string)BGR ! appsink\n"
     ]
    }
   ],
   "source": [
    "camera_aspect_ratio = 16./9.\n",
    "INPUT_IMG_SIZE = 224\n",
    "\n",
    "gstream_pipeline = (\n",
    "    \"nvarguscamerasrc ! \"\n",
    "    \"video/x-raw(memory:NVMM), \"\n",
    "    \"width=(int){capture_width:d}, height=(int){capture_height:d}, \"\n",
    "    \"format=(string)NV12, framerate=(fraction){framerate:d}/1 ! \"\n",
    "    \"nvvidconv top={crop_top:d} bottom={crop_bottom:d} left={crop_left:d} right={crop_right:d} flip-method={flip_method:d} ! \"\n",
    "    \"video/x-raw, width=(int){display_width:d}, height=(int){display_height:d}, format=(string)BGRx ! \"\n",
    "    \"videoconvert ! \"\n",
    "    \"video/x-raw, format=(string)BGR ! appsink\".format(\n",
    "        capture_width = int(INPUT_IMG_SIZE*camera_aspect_ratio),\n",
    "        capture_height = INPUT_IMG_SIZE,\n",
    "        framerate = 60,\n",
    "        crop_top = 0,\n",
    "        crop_bottom = INPUT_IMG_SIZE,\n",
    "        crop_left = int(INPUT_IMG_SIZE*(camera_aspect_ratio-1)/2),\n",
    "        crop_right = int(INPUT_IMG_SIZE*(camera_aspect_ratio+1)/2),\n",
    "        flip_method = 2,\n",
    "        display_width = INPUT_IMG_SIZE,\n",
    "        display_height = INPUT_IMG_SIZE,\n",
    "    )\n",
    ")\n",
    "print(gstream_pipeline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d862f802-f0ec-4edf-989a-daa7969424fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# bufferless VideoCapture\n",
    "import cv2, queue, threading, time\n",
    "\n",
    "class VideoCapture:\n",
    "    def __init__(self, stream, apiPreference):\n",
    "        self.cap = cv2.VideoCapture(stream, apiPreference)\n",
    "        self.q = queue.Queue()\n",
    "        t = threading.Thread(target=self._reader)\n",
    "        t.daemon = True\n",
    "        t.start()\n",
    "\n",
    "    # read frames as soon as they are available, keeping only most recent one\n",
    "    def _reader(self):\n",
    "        while True:\n",
    "            ret, frame = self.cap.read()\n",
    "            if not ret:\n",
    "                break\n",
    "            if not self.q.empty():\n",
    "                try:\n",
    "                    self.q.get_nowait()   # discard previous (unprocessed) frame\n",
    "                except queue.Empty:\n",
    "                    pass\n",
    "            self.q.put((ret, frame))\n",
    "\n",
    "    def read(self):\n",
    "        return self.q.get()\n",
    "    \n",
    "    def release(self):\n",
    "        return self.cap.release()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "713e4833-4e8a-4883-a21b-38eb36ec392e",
   "metadata": {},
   "source": [
    "## Load image in the TensorRT pipeline\n",
    "\n",
    "Next, let's define a function that will preprocess the image, which is originally in HWC/BGR8 format. It is formated, normalized and loaded in the CUDA processing pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b3b4edef-b585-4a5e-a6a7-c150e3a026ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import torchvision.transforms as transforms\n",
    "import PIL.Image\n",
    "\n",
    "# Normalization values can be fine-tuned for your camera. Still, default values generally perform well.\n",
    "mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()\n",
    "std = torch.Tensor([0.229, 0.224, 0.225]).cuda()\n",
    "device = torch.device('cuda')\n",
    "\n",
    "def preprocess(image):\n",
    "    global device\n",
    "    device = torch.device('cuda')\n",
    "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    image = PIL.Image.fromarray(image)\n",
    "    image = transforms.functional.to_tensor(image).to(device)\n",
    "    image.sub_(mean[:, None, None]).div_(std[:, None, None])\n",
    "    return image[None, ...]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b8a8e0c-cb68-4cdc-aa65-c7f298c7a923",
   "metadata": {},
   "source": [
    "## Pose estimation inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9e30b7fd-e7d5-4afc-b9a8-fe8f3c38dcbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_keypoints(counts, objects, peak, indexBody=0):\n",
    "    #if indexBody<counts[0]:\n",
    "    #    return None\n",
    "    kpoint = []\n",
    "    human = objects[0][indexBody]\n",
    "    C = human.shape[0]\n",
    "    for j in range(C):\n",
    "        k = int(human[j])\n",
    "        if k >= 0:\n",
    "            peak = peaks[0][j][k]   # peak[1]:width, peak[0]:height\n",
    "            kpoint.append([float(peak[1]),float(peak[0])])\n",
    "        else:        \n",
    "            kpoint.append([None, None])\n",
    "    return np.array(kpoint)\n",
    "\n",
    "def get_cmap_paf(image):\n",
    "        data = preprocess(image)\n",
    "        cmap, paf = model_estimation(data)\n",
    "        cmap, paf = cmap.detach().cpu(), paf.detach().cpu()\n",
    "        return cmap, paf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfc8bb7b-e050-4752-b9ad-0f1df82e7161",
   "metadata": {},
   "source": [
    "## Keypoints pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f4fbdc68-5cc1-40c5-a86b-9887dd3a1c2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getLengthLimb(data, keypoint1: int, keypoint2: int):\n",
    "    if type(data[keypoint1, 0]) != type(None) and type(data[keypoint2, 0]) != type(None):\n",
    "        return np.linalg.norm([data[keypoint1, 0:2] - data[keypoint2, 0:2]])\n",
    "    return 0\n",
    "\n",
    "def preprocess_keypoints(keypoints:np.ndarray):\n",
    "    if type(keypoints) != type(None):\n",
    "        assert keypoints.shape == (18,2)\n",
    "        # Find bounding box\n",
    "        min_x, max_x = float(\"inf\"), 0.0\n",
    "        min_y, max_y = float(\"inf\"), 0.0\n",
    "        for k in keypoints:\n",
    "            if type(k[0]) != type(None):  # If keypoint exists\n",
    "                min_x = min(min_x, k[0])\n",
    "                max_x = max(max_x, k[0])\n",
    "                min_y = min(min_y, k[1])\n",
    "                max_y = max(max_y, k[1])\n",
    "\n",
    "        # Centering\n",
    "        np.subtract(\n",
    "            keypoints[:, 0],\n",
    "            (min_x + max_x) / 2.,\n",
    "            where=keypoints[:, 0] != None,\n",
    "            out=keypoints[:, 0],\n",
    "        )\n",
    "        np.subtract(\n",
    "            (min_y + max_y) / 2.,\n",
    "            keypoints[:, 1],\n",
    "            where=keypoints[:, 0] != None,\n",
    "            out=keypoints[:, 1],\n",
    "        )\n",
    "\n",
    "        # Scaling  \n",
    "        normalizedPartsLength = np.array(\n",
    "            [\n",
    "                getLengthLimb(keypoints, 6, 12) * (16.0 / 5.2),  # Torso right\n",
    "                getLengthLimb(keypoints, 5, 11) * (16.0 / 5.2),  # Torso left\n",
    "                getLengthLimb(keypoints, 0, 17) * (16.0 / 2.5),  # Neck\n",
    "                getLengthLimb(keypoints, 12, 14) * (16.0 / 3.6),  # Right thigh\n",
    "                getLengthLimb(keypoints, 14, 16) * (16.0 / 3.5),  # Right lower leg\n",
    "                getLengthLimb(keypoints, 11, 13) * (16.0 / 3.6),  # Left thigh\n",
    "                getLengthLimb(keypoints, 13, 15) * (16.0 / 3.5),  # Left lower leg\n",
    "            ]\n",
    "        )\n",
    "        \n",
    "        # Mean of non-zero lengths\n",
    "        normalizedPartsLength = normalizedPartsLength[normalizedPartsLength > 0.0]\n",
    "        if len(normalizedPartsLength)>0:\n",
    "            scaleFactor = np.mean(normalizedPartsLength)\n",
    "        else:\n",
    "            return None\n",
    "\n",
    "        # Populate None keypoints with 0s\n",
    "        keypoints[keypoints == None] = 0.0\n",
    "\n",
    "        # Normalize\n",
    "        np.divide(keypoints, scaleFactor, out=keypoints[:, 0:2])\n",
    "\n",
    "        if np.any((keypoints > 1.0) | (keypoints < -1.0)):\n",
    "            #print(\"Scaling error\")\n",
    "            return None\n",
    "\n",
    "        return keypoints.astype('float32')\n",
    "    else: return None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9724a50-0c82-4e0e-8278-80a75a93af1b",
   "metadata": {},
   "source": [
    "import ipywidgets\n",
    "from IPython.display import display\n",
    "image_w = ipywidgets.Image(format='jpeg')\n",
    "display(image_w)## Processing loop\n",
    "\n",
    "- Read image\n",
    "- Pre-process to Torch format\n",
    "- Infere key-points\n",
    "- Draw skeleton on the input image\n",
    "- Update in output window."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d2fa4287-ff0a-4386-96ac-4633ec1048b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47e5a0532d914721b23e825fb431df01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Image(value=b'', format='jpeg')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets\n",
    "from IPython.display import display\n",
    "image_w = ipywidgets.Image(format='jpeg')\n",
    "display(image_w)\n",
    "\n",
    "def show_image(img, label:str=None, fps:float=None):\n",
    "    if label:\n",
    "        label = label.replace('_', ' ')\n",
    "        img = cv2.putText(img, label, (10,25), cv2.FONT_HERSHEY_DUPLEX, .7, (0,), 2, cv2.LINE_AA)\n",
    "        img = cv2.putText(img, label, (10,25), cv2.FONT_HERSHEY_DUPLEX, .7, (255,255,255), 1, cv2.LINE_AA)\n",
    "    if fps:\n",
    "        fps = 'FPS: {:.2f}'.format(fps)\n",
    "        img = cv2.putText(img, fps, (10,img.shape[0]-10), cv2.FONT_HERSHEY_DUPLEX, .7, (0,0,0), 2, cv2.LINE_AA)\n",
    "        img = cv2.putText(img, fps, (10,img.shape[0]-10), cv2.FONT_HERSHEY_DUPLEX, .7, (255,255,255), 1, cv2.LINE_AA)\n",
    "    image_w.value = bytes(cv2.imencode('.jpg', img)[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7445e59f-e997-423a-88a8-0c650c5685c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Video processing stopped\n"
     ]
    }
   ],
   "source": [
    "from collections import deque\n",
    "import time\n",
    "\n",
    "buffer_size = 5\n",
    "processing_times = deque(buffer_size*[0.], buffer_size)\n",
    "\n",
    "cap_CSI = VideoCapture(gstream_pipeline, cv2.CAP_GSTREAMER)\n",
    "\n",
    "try:\n",
    "    while True:\n",
    "        start_time = time.time()\n",
    "            \n",
    "        # Get image\n",
    "        re, image = cap_CSI.read()\n",
    "        \n",
    "        if re:\n",
    "            \n",
    "            # TRT-Pose inference\n",
    "            cmap, paf = get_cmap_paf(image) # Pose estimation inference\n",
    "            counts, objects, peaks = parse_objects(cmap, paf) # Matching algorithm\n",
    "            keypoints = get_keypoints(counts, objects, peaks) # BODY18 model formating\n",
    "\n",
    "            # Classification inference\n",
    "            label_pose = None\n",
    "            keypoints = preprocess_keypoints(keypoints)\n",
    "            if type(keypoints) != type(None):\n",
    "                x = tf.constant(np.expand_dims(keypoints, axis=0), dtype=tf.float32)\n",
    "                prediction = infer_classification(x)\n",
    "                label_pose = classificationLabels[np.argmax(prediction['dense_20'][0])]\n",
    "\n",
    "            # Display image locally\n",
    "            draw_objects(image, counts, objects, peaks)\n",
    "            show_image(image, label_pose, np.mean(processing_times))\n",
    "        else:\n",
    "            raise RuntimeError('Could not read image from camera')\n",
    "            \n",
    "        processing_times.appendleft(1./(time.time() - start_time))\n",
    "            \n",
    "except (KeyboardInterrupt, RuntimeError) as e:\n",
    "    cap_CSI.release()\n",
    "    print('Video processing stopped')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ca10a40-3aea-46c5-8207-8908336a69b5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
